from src.models.alexnet import AlexNet
from src.models.simplenet import SimpleNet


# Restrict ``from <this module> import *`` to the model factory function.
__all__ = ['create_classification_net']


def create_classification_net(name, num_classes):
    """Construct a classification network selected by architecture name.

    Args:
        name: Architecture identifier. Accepted values are ``"alexnet"``
            and ``"simplenet"`` (exact, case-sensitive match).
        num_classes: Passed unchanged as the ``num_classes`` keyword
            argument of the selected model's constructor.

    Returns:
        A new ``AlexNet`` or ``SimpleNet`` instance.

    Raises:
        ValueError: If ``name`` is not one of the accepted identifiers. The
            message includes the rejected value and the accepted choices so
            a misconfigured name is easy to diagnose.
    """
    if name == "alexnet":
        return AlexNet(num_classes=num_classes)
    if name == "simplenet":
        return SimpleNet(num_classes=num_classes)
    # Plain equality checks (rather than a dict lookup) keep the failure mode
    # a ValueError even for unhashable inputs.
    raise ValueError(
        f"unsupported architecture {name!r}; "
        "expected one of: 'alexnet', 'simplenet'."
    )
